/**
 * 
 */
package edu.umd.clip.lm.tests;

import java.io.*;
import java.nio.channels.FileChannel;

import edu.berkeley.nlp.util.*;
import edu.umd.clip.lm.model.*;
import edu.umd.clip.lm.model.data.ContextFuturesPair;
import edu.umd.clip.lm.model.data.OnDiskTrainingDataReader;
import edu.umd.clip.lm.model.data.ReadableTrainingData;
import edu.umd.clip.lm.model.data.TrainingDataBlock;
import edu.umd.clip.lm.model.data.TrainingDataReader;
import edu.umd.clip.lm.model.data.TupleCountPair;
import edu.umd.clip.lm.questions.*;
import edu.umd.clip.lm.util.tree.*;
import edu.umd.clip.lm.util.*;

/**
 * @author Denis Filimonov <den@cs.umd.edu>
 *
 */
/**
 * Command-line utility that reports parameter counts for a language model's
 * history tree, given a training data file.
 *
 * <p>It prints four numbers to standard output: leaf parameters (distinct
 * tuples observed per cluster, times three), question parameters, lambda
 * parameters, and their total.
 */
public class CountParameters {
	/** Command-line options, populated reflectively by {@link OptionParser}. */
	public static class Options {
        @Option(name = "-config", usage = "XML config file")
		public String config;
        @Option(name = "-input", required = false, usage = "Training data file (Default: stdin)")
		public String input;
        @Option(name = "-forest", required = false, usage = "LM ID to train (default: " + LanguageModel.PRIMARY_LM_ID + ")")
		public String lm = LanguageModel.PRIMARY_LM_ID;        
	}

	/**
	 * Entry point: loads the experiment configuration, routes every training
	 * context down the history tree to a leaf, accumulates the futures seen at
	 * each leaf, and prints the resulting parameter counts.
	 *
	 * <p>An {@link IOException} while reading the training data is reported via
	 * its stack trace and no counts are printed.
	 *
	 * @param args command-line arguments, see {@link Options}
	 * @throws ClassNotFoundException 
	 */
	public static void main(String[] args) throws ClassNotFoundException {
		OptionParser optParser = new OptionParser(Options.class);
		Options opts = (Options) optParser.parse(args, true);

		Experiment.initialize(opts.config);
		final Experiment experiment = Experiment.getInstance();
		experiment.buildPrefixes();
		experiment.buildWordPrefixes();

		LanguageModel lm = experiment.getLM(opts.lm);

		final BinaryTree<HistoryTreePayload> tree = lm.getHistoryTree();

		FileInputStream input = null;
		try {
			input = openInput(opts.input);
			FileChannel channel = input.getChannel();
			TrainingDataReader reader = new OnDiskTrainingDataReader(channel);
			ReadableTrainingData data = new ReadableTrainingData(reader);

			// Number every node of the tree in post-order, starting at 1.
			int clusterid = 0;
			BinaryTreeIterator<HistoryTreePayload> idIterator = tree.getPostOrderIterator();
			while (idIterator.hasNext()) {
				idIterator.next().clusterid = ++clusterid;
			}
			// NOTE(review): ids run 1..clusterid, so clusterid + 1 slots would be
			// enough for indexing; the original "+ 2" is kept because this value
			// is also reported as the lambda parameter count below.
			final int totalClusterCount = clusterid + 2;

			Long2IntMap[] distributions = new Long2IntMap[totalClusterCount];
			for (int i = 0; i < totalClusterCount; ++i) {
				distributions[i] = new Long2IntMap(5);
			}

			// Route each context to its leaf and accumulate the future counts there.
			while (data.hasNext()) {
				TrainingDataBlock block = data.next();
				for (ContextFuturesPair pair : block) {
					BinaryTree<HistoryTreePayload> node = tree;
					while (!node.isLeaf()) {
						if (node.getPayload().question.test(pair.getContext())) {
							node = node.getRight();
						} else {
							node = node.getLeft();
						}
					}
					Long2IntMap counts = distributions[node.getPayload().clusterid];
					for (TupleCountPair tc : pair.getFutures()) {
						counts.addAndGet(tc.tuple, tc.count);
					}
				}
			}

			// Totals are kept in long so that large models cannot overflow int
			// (the leaf count is additionally multiplied by 3).
			long leafParameters = 0;
			for (int i = 0; i < totalClusterCount; ++i) {
				leafParameters += distributions[i].size();
			}
			leafParameters *= 3; // 3 for (word,tag,count)

			long lambdaParameters = totalClusterCount;

			// An in-set question contributes one parameter per value in its set;
			// any other question contributes a single parameter.
			long questionParameters = 0;
			BinaryTreeIterator<HistoryTreePayload> questionIterator = tree.getPostOrderIterator();
			while (questionIterator.hasNext()) {
				HistoryTreePayload payload = questionIterator.next();
				if (payload != null && payload.question != null) {
					if (payload.question.getQuestionType() == Question.IN_SET_QUESTION) {
						InSetQuestion q = (InSetQuestion) payload.question;
						questionParameters += q.getValues().length;
					} else {
						++questionParameters;
					}
				}
			}

			System.out.printf("Leaf parameters: %d\n", leafParameters);
			System.out.printf("Question Parameters: %d\n", questionParameters);
			System.out.printf("Lambda parameters: %d\n", lambdaParameters);
			System.out.printf("Total: %d\n", leafParameters + questionParameters + lambdaParameters);
		} catch(IOException e) {
			e.printStackTrace();
		} finally {
			// Closing the stream also closes the channel obtained from it.
			if (input != null) {
				try {
					input.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}

	/**
	 * Opens the training data source.
	 *
	 * @param path file to read, or {@code null} to read from standard input as
	 *             the {@code -input} option's usage text promises
	 * @return an open stream that the caller must close
	 * @throws FileNotFoundException if {@code path} cannot be opened for reading
	 */
	private static FileInputStream openInput(String path) throws FileNotFoundException {
		if (path == null) {
			// NOTE(review): assumes OnDiskTrainingDataReader can consume a
			// non-seekable channel such as a pipe -- confirm.
			return new FileInputStream(FileDescriptor.in);
		}
		return new FileInputStream(path);
	}

}
